/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

const fs = require('fs');
const path = require('path');

const argparse = require('argparse');
const canvas = require('canvas');
const tf = require('@tensorflow/tfjs');
const synthesizer = require('./synthetic_images');

// Width and height, in pixels, of the square canvas that the synthetic
// examples are drawn on. Chosen to match the input size of MobileNet.
const CANVAS_SIZE = 224;

// Layers whose names begin with one of these prefixes are the ones that get
// unfrozen for the fine-tuning phase.
const topLayerGroupNames = ['conv_pw_9', 'conv_pw_10', 'conv_pw_11'];

// The truncated base ends at the ReLU layer of the last group listed above.
const topLayerName = topLayerGroupNames.slice(-1)[0] + '_relu';

// Per-column factors applied to `yTrue` inside the loss function. Only the
// first column (the 0-1 shape indicator) is rescaled, so that shape and
// bounding-box predictions make balanced contributions to the final loss.
const LABEL_MULTIPLIER = [CANVAS_SIZE, 1, 1, 1, 1];

/**
 * Custom loss function for object detection.
 *
 * The loss is the mean squared error between `yPred` and a rescaled copy of
 * `yTrue`. Each column of `yTrue` is multiplied by the matching entry of
 * `LABEL_MULTIPLIER` before the comparison. Only the first column (the 0-1
 * shape indicator) is actually rescaled, which brings it to the same scale
 * as the pixel-valued bounding-box columns so that shape and bounding-box
 * errors make balanced contributions to the loss.
 *
 * @param {tf.Tensor} yTrue True labels. Shape: [batchSize, 5].
 *   The first column is a 0-1 indicator for whether the shape is a triangle
 *   (0) or a rectangle (1). The remaining four columns are the bounding box
 *   of the target shape: [left, right, top, bottom], in units of pixels.
 *   The bounding box values are in the range [0, CANVAS_SIZE).
 * @param {tf.Tensor} yPred Predicted labels. Shape: the same as `yTrue`.
 * @return {tf.Tensor} The loss tensor produced by
 *   `tf.metrics.meanSquaredError`.
 */
function customLossFunction(yTrue, yPred) {
  const computeLoss = () => {
    // Stretch the shape-indicator column of the targets before comparing
    // them with the predictions.
    const scaledTargets = yTrue.mul(LABEL_MULTIPLIER);
    return tf.metrics.meanSquaredError(scaledTargets, yPred);
  };
  // tf.tidy() cleans up the intermediate scaled-target tensor.
  return tf.tidy(computeLoss);
}

/**
 * Loads MobileNet, removes the top part, and freezes all the layers.
 *
 * The top removal and layer freezing are preparation for transfer learning.
 *
 * Also gets handles to the layers that will be unfrozen during the fine-tuning
 * phase of the training, i.e., the layers of the truncated base whose names
 * start with one of the prefixes in `topLayerGroupNames`.
 *
 * @return {Promise<{truncatedBase: tf.LayersModel,
 *     fineTuningLayers: tf.layers.Layer[]}>} Resolves to an object with
 *   - `truncatedBase`: the truncated MobileNet, with all layers frozen.
 *   - `fineTuningLayers`: the (still frozen) layers of `truncatedBase` that
 *     can be unfrozen for fine-tuning.
 * @throws {Error} If no layer of the truncated base matches any of the
 *   prefixes in `topLayerGroupNames`.
 */
async function loadTruncatedBase() {
  // TODO(cais): Add unit test.
  const mobilenet = await tf.loadLayersModel(
      'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');

  // Build a model that outputs an internal activation of MobileNet: it shares
  // MobileNet's inputs but stops at the layer named `topLayerName`.
  const topLayer = mobilenet.getLayer(topLayerName);
  const truncatedBase =
      tf.model({inputs: mobilenet.inputs, outputs: topLayer.output});

  // Freeze every layer of the base, and remember the ones that belong to the
  // fine-tuning groups so that the caller can unfreeze them later.
  const fineTuningLayers = [];
  for (const layer of truncatedBase.layers) {
    layer.trainable = false;
    if (topLayerGroupNames.some(prefix => layer.name.startsWith(prefix))) {
      fineTuningLayers.push(layer);
    }
  }

  // A single matching layer is enough for fine-tuning; only an empty list is
  // an error.
  tf.util.assert(
      fineTuningLayers.length > 0,
      () => `Did not find any layers that match the prefixes ` +
          `${topLayerGroupNames}`);
  return {truncatedBase, fineTuningLayers};
}

/**
 * Creates the new head (i.e., output sub-model) that gets attached to the
 * top of the truncated base for object detection.
 *
 * The head flattens its input, passes it through a 200-unit ReLU dense layer
 * and ends in a 5-unit dense layer.
 *
 * @param {tf.Shape} inputShape Input shape of the new model.
 * @returns {tf.Model} The new head model.
 */
function buildNewHead(inputShape) {
  const head = tf.sequential();
  const layerFactories = [
    () => tf.layers.flatten({inputShape}),
    () => tf.layers.dense({units: 200, activation: 'relu'}),
    // Five output units:
    //   - The first is a shape indicator: predicts whether the target
    //     shape is a triangle or a rectangle.
    //   - The remaining four units are for bounding-box prediction:
    //     [left, right, top, bottom] in the unit of pixels.
    () => tf.layers.dense({units: 5})
  ];
  // Each layer is created right before it is appended to the head.
  layerFactories.forEach(makeLayer => {
    head.add(makeLayer());
  });
  return head;
}

/**
 * Builds object-detection model from MobileNet.
 *
 * The truncated, frozen MobileNet base is loaded and a freshly built head is
 * applied to the base's first output; the combined model shares the base's
 * inputs.
 *
 * @returns {Promise<{model: tf.LayersModel,
 *     fineTuningLayers: tf.layers.Layer[]}>} Resolves to an object with
 *   - `model`: the newly-built model for simple object detection.
 *   - `fineTuningLayers`: the layers that can be unfrozen during fine-tuning.
 */
async function buildObjectDetectionModel() {
  const {truncatedBase, fineTuningLayers} = await loadTruncatedBase();
  const baseOutput = truncatedBase.outputs[0];

  // The head's input shape is the base's output shape without its leading
  // (batch) dimension.
  const headInputShape = baseOutput.shape.slice(1);
  const head = buildNewHead(headInputShape);

  // Chain the head onto the base.
  const detectionOutput = head.apply(baseOutput);
  return {
    model: tf.model({inputs: truncatedBase.inputs, outputs: detectionOutput}),
    fineTuningLayers
  };
}

/**
 * Command-line entry point.
 *
 * Parses the command-line flags, generates a batch of synthetic training
 * examples, builds the object-detection model and trains it in two phases
 * (initial transfer learning with the base frozen, then fine-tuning with the
 * fine-tuning layers unfrozen), and finally saves the trained model to
 * `./dist/object_detection_model`.
 *
 * Any error is printed to stderr and makes the process exit with a non-zero
 * exit code.
 */
(async function main() {
  // Data-related settings.
  const numCircles = 10;
  const numLines = 10;

  const parser = new argparse.ArgumentParser();
  parser.addArgument('--gpu', {
    action: 'storeTrue',
    help: 'Use tfjs-node-gpu for training (required CUDA and CuDNN)'
  });
  parser.addArgument(
      '--numExamples',
      {type: 'int', defaultValue: 2000, help: 'Number of training exapmles'});
  parser.addArgument('--validationSplit', {
    type: 'float',
    defaultValue: 0.15,
    help: 'Validation split to be used during training'
  });
  parser.addArgument('--batchSize', {
    type: 'int',
    defaultValue: 128,
    help: 'Batch size to be used during training'
  });
  parser.addArgument('--initialTransferEpochs', {
    type: 'int',
    defaultValue: 100,
    help: 'Number of training epochs in the initial transfer ' +
        'learning (i.e., 1st) phase'
  });
  parser.addArgument('--fineTuningEpochs', {
    type: 'int',
    defaultValue: 100,
    help: 'Number of training epochs in the fine-tuning (i.e., 2nd) phase'
  });
  parser.addArgument('--logDir', {
    type: 'string',
    help: 'Optional tensorboard log directory, to which the loss ' +
    'values will be logged during model training.'
  });
  parser.addArgument('--logUpdateFreq', {
    type: 'string',
    defaultValue: 'batch',
    // `choices` makes the parser reject any other value up front.
    choices: ['batch', 'epoch'],
    help: 'Frequency at which the loss will be logged to tensorboard.'
  });
  const args = parser.parseArgs();

  // Load the CPU or GPU flavor of tfjs-node, depending on the --gpu flag.
  // The loaded module also supplies the tensorBoard callback used below.
  let tfn;
  if (args.gpu) {
    console.log('Training using GPU.');
    tfn = require('@tensorflow/tfjs-node-gpu');
  } else {
    console.log('Training using CPU.');
    tfn = require('@tensorflow/tfjs-node');
  }

  // Creates the `callbacks` value for a call to fit(): a new tensorBoard
  // callback when --logDir is given, and null otherwise.
  const createCallbacks = () => args.logDir == null ?
      null :
      tfn.node.tensorBoard(args.logDir, {updateFreq: args.logUpdateFreq});

  const modelSaveURL = 'file://./dist/object_detection_model';

  const tBegin = tf.util.now();
  console.log(`Generating ${args.numExamples} training examples...`);
  const synthDataCanvas = canvas.createCanvas(CANVAS_SIZE, CANVAS_SIZE);
  const synth =
      new synthesizer.ObjectDetectionImageSynthesizer(synthDataCanvas, tf);
  const {images, targets} =
      await synth.generateExampleBatch(args.numExamples, numCircles, numLines);

  const {model, fineTuningLayers} = await buildObjectDetectionModel();
  model.compile({loss: customLossFunction, optimizer: tf.train.rmsprop(5e-3)});
  model.summary();

  // Initial phase of transfer learning.
  console.log('Phase 1 of 2: initial transfer learning');
  await model.fit(images, targets, {
    epochs: args.initialTransferEpochs,
    batchSize: args.batchSize,
    validationSplit: args.validationSplit,
    callbacks: createCallbacks()
  });

  // Fine-tuning phase of transfer learning.
  // Unfreeze layers for fine-tuning.
  for (const layer of fineTuningLayers) {
    layer.trainable = true;
  }
  // Re-compile so that the changed `trainable` flags take effect, this time
  // with a lower learning rate.
  model.compile({loss: customLossFunction, optimizer: tf.train.rmsprop(2e-3)});
  model.summary();

  // Do fine-tuning.
  // The batch size is reduced to avoid CPU/GPU OOM. This has
  // to do with the unfreezing of the fine-tuning layers above,
  // which leads to higher memory consumption during backpropagation.
  // The halved value is rounded down and kept at 1 or more, because fit()
  // requires a positive integer batch size (an odd --batchSize would
  // otherwise yield a fractional one).
  const fineTuningBatchSize = Math.max(1, Math.floor(args.batchSize / 2));
  console.log('Phase 2 of 2: fine-tuning phase');
  await model.fit(images, targets, {
    epochs: args.fineTuningEpochs,
    batchSize: fineTuningBatchSize,
    validationSplit: args.validationSplit,
    callbacks: createCallbacks()
  });

  // Save model.
  // First make sure that the base directory exists.
  const modelSavePath = modelSaveURL.replace('file://', '');
  const dirName = path.dirname(modelSavePath);
  if (!fs.existsSync(dirName)) {
    fs.mkdirSync(dirName);
  }
  await model.save(modelSaveURL);
  console.log(`Model training took ${(tf.util.now() - tBegin) / 1e3} s`);
  console.log(`Trained model is saved to ${modelSaveURL}`);
  console.log(
      `\nNext, run the following command to test the model in the browser:`);
  console.log(`\n  yarn watch`);
})().catch(err => {
  // Do not leave the promise floating: report the failure and signal it
  // through the exit code.
  console.error(err);
  process.exitCode = 1;
});
